# SPDX-FileCopyrightText: Copyright (c) 2022-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import configparser
import time
from operator import attrgetter
from pathlib import Path
from typing import Dict, List, Optional, Union

import numpy as np
import tensorrt_llm
import tensorrt_llm.logger as logger
import torch
from safetensors import safe_open
from tensorrt_llm._utils import str_dtype_to_torch, torch_to_numpy
from tensorrt_llm.mapping import Mapping
from tensorrt_llm.models import LLaMAForCausalLM
from tensorrt_llm.models.quantized.quant import get_dummy_quant_scales
from tensorrt_llm.quantization import QuantMode


def get_scaling_factors(
    model_path: Union[str, Path],
    num_layers: int,
    quant_mode: Optional[QuantMode] = None,
) -> Optional[Dict[str, List[float]]]:
    """Get the scaling factors for a LLaMA model.

    Loads ``model_path`` with ``np.load`` and collects, for every layer in
    ``range(num_layers)``, the activation and weight scaling factors of the
    attention and MLP projections. The Q, K and V factors are fused into a
    single ``qkv`` entry by taking the maximum of the three values.

    Args:
        model_path: Path to the file holding the scaling factors of the
            quantized LLaMA model. If None, a warning is logged and the
            result of ``get_dummy_quant_scales(num_layers)`` is returned.
        num_layers: Number of layers to read scaling factors for.
        quant_mode: Optional quantization mode. ``qkv_output`` is only
            populated when it reports an FP8 KV cache; the KV cache is not
            calibrated, so the value is 1.0 for every layer.

    Returns:
        dict: Maps each of the names below to a list with one scaling factor
        per layer. ``qkv_output`` is an empty list when ``quant_mode`` is
        None or does not use an FP8 KV cache.

        example:

        {
            'qkv_act': qkv_act_scale,
            'qkv_weights': qkv_weights_scale,
            'qkv_output' : qkv_outputs_scale,
            'dense_act': dense_act_scale,
            'dense_weights': dense_weights_scale,
            'fc_act': fc_act_scale,
            'fc_weights': fc_weights_scale,
            'gate_act': gate_act_scale,
            'gate_weights': gate_weights_scale,
            'proj_act': proj_act_scale,
            'proj_weights': proj_weights_scale,
        }
    """

    if model_path is None:
        logger.warning(
            "--quantized_fp8_model_path not specified. "
            "Initialize quantization scales automatically."
        )
        return get_dummy_quant_scales(num_layers)
    weight_dict = np.load(model_path)

    def read_scale(layer, module, kind):
        # Keys look like "_np:layers:<layer>:<module>:<kind>_scaling_factor".
        key = f"_np:layers:{layer}:{module}:{kind}_scaling_factor"
        return weight_dict[key].item()

    has_fp8_kv_cache = quant_mode is not None and quant_mode.has_fp8_kv_cache()

    scaling_factor = {
        "qkv_act": [],
        "qkv_weights": [],
        "qkv_output": [],
        "dense_act": [],
        "dense_weights": [],
        "fc_act": [],
        "fc_weights": [],
        "gate_act": [],
        "gate_weights": [],
        "proj_act": [],
        "proj_weights": [],
    }

    # (prefix of the output keys, module path inside the loaded file)
    linear_modules = (
        ("dense", "attention:dense"),
        ("fc", "mlp:fc"),
        ("gate", "mlp:gate"),
        ("proj", "mlp:proj"),
    )

    for layer in range(num_layers):
        # Q, K and V share one fused scale: the largest of the three.
        for out_key, kind in (("qkv_act", "activation"), ("qkv_weights", "weights")):
            scaling_factor[out_key].append(
                max(
                    read_scale(layer, f"attention:qkv:{part}", kind)
                    for part in ("q", "k", "v")
                )
            )
        if has_fp8_kv_cache:
            # Not calibrating KV cache.
            scaling_factor["qkv_output"].append(1.0)
        for name, module in linear_modules:
            scaling_factor[f"{name}_act"].append(
                read_scale(layer, module, "activation")
            )
            scaling_factor[f"{name}_weights"].append(
                read_scale(layer, module, "weights")
            )

    for k, v in scaling_factor.items():
        # 'qkv_output' is intentionally left empty without an FP8 KV cache;
        # demanding num_layers entries there made every such call fail.
        expected = num_layers
        if k == "qkv_output" and not has_fp8_kv_cache:
            expected = 0
        assert (
            len(v) == expected
        ), f"Expect scaling factor {k} of length {expected}, got {len(v)}"

    return scaling_factor


def gen_suffix(rank, use_smooth_quant, quant_per_channel):
    """Build the file-name suffix of a per-rank weight file.

    The suffix always ends in ``"<rank>.bin"``. With SmoothQuant enabled it
    is prefixed by ``"int8."``, followed by ``"col."`` when per-channel
    quantization is also enabled. ``quant_per_channel`` has no effect
    without ``use_smooth_quant``.
    """
    pieces = []
    if use_smooth_quant:
        pieces.append("int8.")
        if quant_per_channel:
            pieces.append("col.")
    pieces.append(f"{rank}.bin")
    return "".join(pieces)


def extract_layer_idx(name):
    """Return the first all-digit component of a dot-separated name.

    The component is returned as a string (not an int). None is returned
    when no component consists solely of digits.
    """
    numeric_parts = (part for part in name.split(".") if part.isdigit())
    return next(numeric_parts, None)


def split(v, tp_size, idx, dim=0):
    """Return shard ``idx`` of ``v`` after cutting it into ``tp_size`` parts.

    With ``tp_size == 1`` the input is returned untouched. Otherwise ``v`` is
    divided with ``np.split`` along ``dim`` (1-D inputs are always divided
    along their only axis, ``dim`` is ignored) and the selected part is
    returned as a C-contiguous array.
    """
    if tp_size == 1:
        return v
    axis = 0 if len(v.shape) == 1 else dim
    shards = np.split(v, tp_size, axis=axis)
    return np.ascontiguousarray(shards[idx])


def dup_kv_weight(v, num_head, tp_size):
    """Repeat every head block of ``v`` so that there are ``tp_size`` blocks.

    ``v`` is viewed as ``num_head`` consecutive row blocks. Each block is
    repeated ``tp_size // num_head`` times back to back, and the result is
    returned as a new 2-D tensor. ``tp_size`` must be a multiple of
    ``num_head``.
    """
    assert tp_size % num_head == 0
    copies_per_head = tp_size // num_head
    rows_per_head = v.shape[0] // num_head
    per_head = v.reshape(num_head, rows_per_head, -1)
    # Insert a "copy" axis after the head axis and broadcast along it.
    tiled = per_head.unsqueeze(1).expand(
        num_head, copies_per_head, rows_per_head, v.shape[1]
    )
    flat = tiled.reshape(num_head * copies_per_head * rows_per_head, -1)
    return flat.clone()


def parse_ft_config(ini_file):
    """Read the ``[llama]`` section of an INI configuration file.

    Returns:
        tuple: ``(hidden_size, num_attention_heads, num_hidden_layers,
        max_position_embeddings, vocab_size, hidden_act, intermediate_size,
        num_key_value_heads)``. ``intermediate_size`` falls back to
        ``4 * hidden_size`` and ``num_key_value_heads`` to None when the
        option is absent.
    """
    parser = configparser.ConfigParser()
    parser.read(ini_file)
    section = "llama"

    hidden_size = parser.getint(section, "hidden_size")
    num_heads = parser.getint(section, "num_attention_heads")
    num_layers = parser.getint(section, "num_hidden_layers")
    max_positions = parser.getint(section, "max_position_embeddings")
    vocab = parser.getint(section, "vocab_size")
    activation = parser.get(section, "hidden_act")
    intermediate = parser.getint(section, "intermediate_size", fallback=None)
    num_kv_heads = parser.getint(section, "num_key_value_heads", fallback=None)

    # Default MLP width when the file does not specify one.
    if intermediate is None:
        intermediate = 4 * hidden_size

    return (
        hidden_size,
        num_heads,
        num_layers,
        max_positions,
        vocab,
        activation,
        intermediate,
        num_kv_heads,
    )


def _assign_linear_weight(module, weight, use_weight_only, quant_type):
    """Store one tensor-parallel weight shard into a TensorRT-LLM linear module.

    Args:
        module: Layer object exposing a ``weight`` parameter and, when
            weight-only quantization is used, a ``per_channel_scale``
            parameter.
        weight: NumPy array holding this rank's shard of the weight.
        use_weight_only: Whether weight-only quantization is enabled.
        quant_type: Torch dtype handed to the quantization op; only used
            when ``use_weight_only`` is true.
    """
    if not use_weight_only:
        module.weight.value = np.ascontiguousarray(weight)
        return

    # The quantization op is given the transposed weight.
    transposed = np.ascontiguousarray(weight.transpose())
    (
        processed_torch_weights,
        torch_weight_scales,
    ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
        torch.tensor(transposed), quant_type
    )
    # workaround for trt not supporting int8 inputs in plugins currently
    module.weight.value = processed_torch_weights.view(dtype=torch.float32).numpy()
    module.per_channel_scale.value = torch_weight_scales.numpy()


def load_from_hf_llama(
    tensorrt_llm_llama: tensorrt_llm.models.LLaMAForCausalLM,
    hf_llama,
    mapping=Mapping(),
    dtype="float32",
):
    """Copy the weights of a HuggingFace LLaMA model into a TensorRT-LLM model.

    Q, K and V projections are fused per layer, every tensor is cast to
    ``dtype`` and converted to NumPy, and the shard belonging to
    ``mapping.tp_rank`` is assigned to the matching TensorRT-LLM parameter.
    Only layers of the pipeline stage ``mapping.pp_rank`` are loaded; the
    embedding is set on the first stage, the final norm and LM head on the
    last one. When the model's ``quant_mode`` is weight-only, linear weights
    are quantized before being assigned.

    Args:
        tensorrt_llm_llama: Destination TensorRT-LLM LLaMA model.
        hf_llama: Source model; must provide ``named_parameters()`` and
            ``config.num_hidden_layers``.
        mapping: Tensor/pipeline parallel mapping of the current rank.
        dtype: Name of the dtype the weights are cast to.
    """
    tensorrt_llm.logger.info("Loading weights from HF LLaMA...")
    tik = time.time()

    quant_mode = getattr(tensorrt_llm_llama, "quant_mode", QuantMode(0))
    # Stays None unless a weight-only mode is active.
    plugin_weight_only_quant_type = None
    if quant_mode.is_int8_weight_only():
        plugin_weight_only_quant_type = torch.int8
    elif quant_mode.is_int4_weight_only():
        plugin_weight_only_quant_type = torch.quint4x2
    use_weight_only = quant_mode.is_weight_only()
    num_kv_heads = tensorrt_llm_llama.num_kv_heads
    mha_mode = num_kv_heads == tensorrt_llm_llama.num_heads
    # Only the non-MHA path needs the head size (for its sanity asserts).
    head_size = (
        None
        if mha_mode
        else tensorrt_llm_llama.hidden_size // tensorrt_llm_llama.num_heads
    )

    # Add a fused "qkv_proj.weight" entry per layer: one concatenated tensor
    # in MHA mode, a [q, k, v] list otherwise (split per tensor later on).
    model_params = dict(hf_llama.named_parameters())
    for layer_id in range(hf_llama.config.num_hidden_layers):
        prefix = f"model.layers.{layer_id}.self_attn."
        q_weight = model_params[prefix + "q_proj.weight"]
        k_weight = model_params[prefix + "k_proj.weight"]
        v_weight = model_params[prefix + "v_proj.weight"]
        if not mha_mode:
            if num_kv_heads < mapping.tp_size:
                # duplicate the KV heads up to tensor_parallel
                k_weight = dup_kv_weight(k_weight, num_kv_heads, mapping.tp_size)
                v_weight = dup_kv_weight(v_weight, num_kv_heads, mapping.tp_size)
            assert (k_weight.shape[0] % (mapping.tp_size * head_size)) == 0
            assert (v_weight.shape[0] % (mapping.tp_size * head_size)) == 0
            qkv_weight = [q_weight, k_weight, v_weight]
        else:
            qkv_weight = torch.cat([q_weight, k_weight, v_weight], dim=0)

        model_params[prefix + "qkv_proj.weight"] = qkv_weight

    torch_dtype = str_dtype_to_torch(dtype)
    layers_per_pipeline_stage = hf_llama.config.num_hidden_layers // mapping.pp_size
    first_layer = mapping.pp_rank * layers_per_pipeline_stage
    # Global layer indices owned by this pipeline stage.
    layers_range = range(first_layer, first_layer + layers_per_pipeline_stage)

    for k, v in model_params.items():
        if isinstance(v, list):
            v = [torch_to_numpy(vv.to(torch_dtype).detach().cpu()) for vv in v]
        else:
            v = torch_to_numpy(v.to(torch_dtype).detach().cpu())

        if "model.embed_tokens.weight" in k:
            if tensorrt_llm_llama.use_parallel_embedding:
                v = split(
                    v,
                    mapping.tp_size,
                    mapping.tp_rank,
                    tensorrt_llm_llama.embedding_sharding_dim,
                )
            if mapping.is_first_pp_rank():
                tensorrt_llm_llama.vocab_embedding.weight.value = v
        elif "model.norm.weight" in k:
            if mapping.is_last_pp_rank():
                tensorrt_llm_llama.ln_f.weight.value = v
        elif "lm_head.weight" in k:
            if mapping.is_last_pp_rank():
                tensorrt_llm_llama.lm_head.weight.value = np.ascontiguousarray(
                    split(v, mapping.tp_size, mapping.tp_rank)
                )
        else:
            layer_idx = extract_layer_idx(k)
            if layer_idx is None or int(layer_idx) not in layers_range:
                continue
            # Index of the layer inside this pipeline stage.
            idx = int(layer_idx) - first_layer
            if idx >= tensorrt_llm_llama.num_layers:
                continue
            layer = tensorrt_llm_llama.layers[idx]

            if "input_layernorm.weight" in k:
                layer.input_layernorm.weight.value = v
            elif "post_attention_layernorm.weight" in k:
                layer.post_layernorm.weight.value = v
            elif "self_attn.qkv_proj.weight" in k:
                if not mha_mode:
                    assert isinstance(v, list) and len(v) == 3
                    wq = split(v[0], mapping.tp_size, mapping.tp_rank)
                    wk = split(v[1], mapping.tp_size, mapping.tp_rank)
                    wv = split(v[2], mapping.tp_size, mapping.tp_rank)
                    split_v = np.concatenate((wq, wk, wv))
                else:
                    # Split Q, K and V separately, then re-fuse this rank's
                    # three slices.
                    q_emb = v.shape[0] // 3
                    model_emb = v.shape[1]
                    v = v.reshape(3, q_emb, model_emb)
                    split_v = split(v, mapping.tp_size, mapping.tp_rank, dim=1)
                    split_v = split_v.reshape(3 * (q_emb // mapping.tp_size), model_emb)
                _assign_linear_weight(
                    layer.attention.qkv,
                    split_v,
                    use_weight_only,
                    plugin_weight_only_quant_type,
                )
            elif "self_attn.o_proj.weight" in k:
                _assign_linear_weight(
                    layer.attention.dense,
                    split(v, mapping.tp_size, mapping.tp_rank, dim=1),
                    use_weight_only,
                    plugin_weight_only_quant_type,
                )
            elif "mlp.up_proj.weight" in k:
                # HF "up_proj" is stored in the TensorRT-LLM "gate" module.
                _assign_linear_weight(
                    layer.mlp.gate,
                    split(v, mapping.tp_size, mapping.tp_rank, dim=0),
                    use_weight_only,
                    plugin_weight_only_quant_type,
                )
            elif "mlp.down_proj.weight" in k:
                _assign_linear_weight(
                    layer.mlp.proj,
                    split(v, mapping.tp_size, mapping.tp_rank, dim=1),
                    use_weight_only,
                    plugin_weight_only_quant_type,
                )
            elif "mlp.gate_proj.weight" in k:
                # HF "gate_proj" is stored in the TensorRT-LLM "fc" module.
                _assign_linear_weight(
                    layer.mlp.fc,
                    split(v, mapping.tp_size, mapping.tp_rank, dim=0),
                    use_weight_only,
                    plugin_weight_only_quant_type,
                )

    tok = time.time()
    t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
    tensorrt_llm.logger.info(f"Weights loaded. Total time: {t}")


def load_from_meta_llama(
    tensorrt_llm_llama: tensorrt_llm.models.LLaMAForCausalLM,
    meta_ckpt_dir,
    mapping=Mapping(),
    dtype="float32",
):
    """Load Meta-format LLaMA checkpoints into a TensorRT-LLM model.

    Reads the ``consolidated.NN.pth`` files in ``meta_ckpt_dir``. When the
    number of files differs from ``mapping.tp_size`` the checkpoints are
    gathered or split so that the current tensor-parallel rank gets its own
    shard. Q and K weights are permuted, fused with V into one QKV weight per
    layer, and all tensors are cast to ``dtype`` before being assigned.
    Only layers of the pipeline stage ``mapping.pp_rank`` are loaded.

    The tensors are assigned as loaded; no weight-only quantization is
    performed by this function.

    The gathered embedding table is cached on the function attribute
    ``load_from_meta_llama.saved_embed`` and reused by later calls.

    NOTE: checkpoints are read with ``torch.load``, which unpickles the
    files; only load checkpoints from trusted sources.

    Args:
        tensorrt_llm_llama: Destination TensorRT-LLM LLaMA model.
        meta_ckpt_dir: Directory containing ``consolidated.*.pth`` files.
        mapping: Tensor/pipeline parallel mapping of the current rank.
        dtype: Name of the dtype the weights are cast to.
    """
    torch_dtype = str_dtype_to_torch(dtype)

    def shard_dim(name):
        # "wo", "w2" and "tok" tensors are concatenated/split along dim 1,
        # all other sharded tensors along dim 0.
        return 1 if any(n in name for n in ("wo", "w2", "tok")) else 0

    def is_replicated(name):
        # norm and rope tensors are not tensor-parallel (no TP).
        return "norm" in name or "rope" in name

    def load_ckpt(file_id):
        return torch.load(
            Path(meta_ckpt_dir, f"consolidated.{file_id:02d}.pth"),
            map_location="cpu",
        )

    def gather_ckpts(ckpts):
        """Merge several checkpoint shards into one state dict."""
        gathered = {}
        for k in ckpts[0]:
            if is_replicated(k):
                gathered[k] = ckpts[0][k].clone()
            else:
                gathered[k] = torch.cat(
                    [pt[k] for pt in ckpts], dim=shard_dim(k)
                ).clone()
        return gathered

    def split_ckpt(ckpt, ranks_per_ckpt, ckpt_rank):
        """Take slice ``ckpt_rank`` out of ``ranks_per_ckpt`` from one shard."""
        shards = {}
        for k in ckpt:
            if is_replicated(k):
                shards[k] = ckpt[k].clone()
                continue
            d = shard_dim(k)
            w = ckpt[k]
            if tensorrt_llm_llama.num_kv_heads < mapping.tp_size and any(
                n in k for n in ("wk", "wv")
            ):
                assert mapping.tp_size % tensorrt_llm_llama.num_kv_heads == 0
                # special case: we need to duplicate KV head
                w = dup_kv_weight(w, tensorrt_llm_llama.num_kv_heads, mapping.tp_size)
            shards[k] = torch.split(w, w.shape[d] // ranks_per_ckpt, dim=d)[
                ckpt_rank
            ].clone()
        return shards

    def get_current_weights(num_ckpts):
        """Return the state dict for the current tensor-parallel rank."""
        if num_ckpts > mapping.tp_size:
            # combine ckpts
            assert (num_ckpts % mapping.tp_size) == 0
            nf = num_ckpts // mapping.tp_size
            fs = nf * mapping.tp_rank
            return gather_ckpts([load_ckpt(f) for f in range(fs, fs + nf)])
        if num_ckpts < mapping.tp_size:
            # split ckpt
            assert (mapping.tp_size % num_ckpts) == 0
            ranks_per_ckpt = mapping.tp_size // num_ckpts
            ckpt_fid = mapping.tp_rank // ranks_per_ckpt
            ckpt_rank = mapping.tp_rank % ranks_per_ckpt
            nH_per_ckpt = tensorrt_llm_llama.num_heads // num_ckpts
            assert (nH_per_ckpt % ranks_per_ckpt) == 0
            return split_ckpt(load_ckpt(ckpt_fid), ranks_per_ckpt, ckpt_rank)

        # num_ckpts == tensor_parallel, 1:1 mapping from files to TP
        return load_ckpt(mapping.tp_rank)

    def permute(w, nH, d, dH):
        # due to MQA's wk, nH*dH != d could be true
        return w.view(nH, dH // 2, 2, d).transpose(1, 2).reshape(nH * dH, d)

    if not hasattr(load_from_meta_llama, "saved_embed"):
        load_from_meta_llama.saved_embed = None

    def gather_embedding(cur_embed, name: str, num_ckpts):
        if mapping.tp_size == 1:
            # even if num_ckpts > 1, get_current_weights will already have it gathered
            return cur_embed
        if load_from_meta_llama.saved_embed is None:
            embeds = [load_ckpt(i)[name] for i in range(num_ckpts)]
            embed = torch.cat(embeds, dim=1).to(torch_dtype)
            # cache the embedding, not needed if no refit
            load_from_meta_llama.saved_embed = torch_to_numpy(embed)
        return load_from_meta_llama.saved_embed

    tensorrt_llm.logger.info("Loading weights from Meta LLaMA checkpoints ...")
    tik = time.time()

    num_kv_heads = tensorrt_llm_llama.num_kv_heads

    ckpts = list(Path(meta_ckpt_dir).glob("consolidated.*.pth"))
    num_ckpts = len(ckpts)
    # llama/llama2 doesn't have MQA. So, simplifying loader logic by not worrying about it.
    assert (
        num_kv_heads > 1 or num_kv_heads >= num_ckpts
    ), f"We don't know how the {num_kv_heads} KV heads are distributed among {num_ckpts} checkpoints."

    head_size = tensorrt_llm_llama.hidden_size // tensorrt_llm_llama.num_heads
    ckpt = get_current_weights(num_ckpts)
    num_layers = tensorrt_llm_llama.num_layers
    first_layer = mapping.pp_rank * num_layers
    # Global layer indices owned by this pipeline stage.
    layers_range = range(first_layer, first_layer + num_layers)

    # Build the fused QKV weight of every layer of this stage.
    for layer_id in layers_range:
        prefix = f"layers.{layer_id}.attention."
        q_weight = permute(
            ckpt[prefix + "wq.weight"].clone(),
            nH=(tensorrt_llm_llama.num_heads // mapping.tp_size),
            d=tensorrt_llm_llama.hidden_size,
            dH=head_size,
        )
        if num_kv_heads < mapping.tp_size and num_ckpts >= mapping.tp_size:
            assert mapping.tp_size % num_kv_heads == 0
            # Raised explicitly (instead of `assert False`) so the guard is
            # not stripped when Python runs with -O.
            raise AssertionError("Not supported yet")
        k_weight = permute(
            ckpt[prefix + "wk.weight"].clone(),
            nH=((num_kv_heads + mapping.tp_size - 1) // mapping.tp_size),
            d=tensorrt_llm_llama.hidden_size,
            dH=head_size,
        )
        v_weight = ckpt[prefix + "wv.weight"].clone()

        ckpt[prefix + "qkv.weight"] = torch.cat([q_weight, k_weight, v_weight], dim=0)

    for k, v in ckpt.items():
        v = torch_to_numpy(v.to(torch_dtype).detach().cpu())
        if "tok_embeddings" in k:
            if not tensorrt_llm_llama.use_parallel_embedding:
                v = gather_embedding(v, k, num_ckpts)
            elif tensorrt_llm_llama.embedding_sharding_dim == 0:
                # this needs a gather and then resplit along different dims
                v = gather_embedding(v, k, num_ckpts)
                v = split(v, mapping.tp_size, mapping.tp_rank, 0)
            if mapping.is_first_pp_rank():
                tensorrt_llm_llama.vocab_embedding.weight.value = v
        elif "output" in k:
            if mapping.is_last_pp_rank():
                tensorrt_llm_llama.lm_head.weight.value = v
        elif k == "norm.weight":
            if mapping.is_last_pp_rank():
                tensorrt_llm_llama.ln_f.weight.value = v
        else:
            # layer specific weights
            layer_idx = extract_layer_idx(k)
            if layer_idx is None:
                continue
            idx = int(layer_idx) - first_layer
            # Skip layers of other pipeline stages. A negative idx (layer of
            # an earlier stage) must be rejected too: it would otherwise wrap
            # around to a wrong layer or fall outside the layer list.
            if idx < 0 or idx >= num_layers:
                continue
            layer = tensorrt_llm_llama.layers[idx]
            if "attention_norm.weight" in k:
                layer.input_layernorm.weight.value = v
            elif "ffn_norm.weight" in k:
                layer.post_layernorm.weight.value = v
            elif "feed_forward.w3.weight" in k:
                layer.mlp.gate.weight.value = v
            elif "feed_forward.w2.weight" in k:
                layer.mlp.proj.weight.value = v
            elif "feed_forward.w1.weight" in k:
                layer.mlp.fc.weight.value = v
            elif "attention.wo.weight" in k:
                layer.attention.dense.weight.value = v
            elif "attention.qkv.weight" in k:
                layer.attention.qkv.weight.value = v

    tok = time.time()
    t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
    tensorrt_llm.logger.info(f"Weights loaded. Total time: {t}")


def load_from_binary(
    tensorrt_llm_llama: LLaMAForCausalLM,
    dir_path,
    mapping=Mapping(),
    fp16=False,
    multi_query_mode=False,
):
    tensorrt_llm.logger.info("Loading weights from FT...")
    tik = time.time()

    quant_mode = getattr(tensorrt_llm_llama, "quant_mode", QuantMode(0))

    (
        n_embd,
        n_head,
        n_layer,
        n_positions,
        vocab_size,
        hidden_act,
        inter_size,
        n_kv_head,
    ) = parse_ft_config(Path(dir_path) / "config.ini")
    np_dtype = np.float16 if fp16 else np.float32

    def fromfile(dir_path, name, shape=None, dtype=None):
        dtype = np_dtype if dtype is None else dtype
        p = dir_path + "/" + name
        if Path(p).exists():
            t = np.fromfile(p, dtype=dtype)
            if shape is not None:
                t = t.reshape(shape)
            return t
        return None

    def set_smoothquant_scale_factors(
        module,
        pre_scale_weight,
        dir_path,
        basename,
        shape,
        per_tok_dyn,
        per_channel,
        is_qkv=False,
        rank=None,
    ):
        suffix = "bin"
        if per_channel:
            if rank is not None:
                suffix = f"{rank}." + suffix
            suffix = "col." + suffix

        col_shape = shape if (per_channel or is_qkv) else [1, 1]

        if per_tok_dyn:
            if pre_scale_weight is not None:
                pre_scale_weight.value = np.array([1.0], dtype=np.float32)
            if is_qkv and not per_channel:
                t = fromfile(
                    dir_path,
                    f"{basename}scale_w_quant_orig.{rank}.{suffix}",
                    col_shape,
                    np.float32,
                )
            else:
                t = fromfile(
                    dir_path,
                    f"{basename}scale_w_quant_orig.{suffix}",
                    col_shape,
                    np.float32,
                )
            module.per_channel_scale.value = t
        else:
            t = fromfile(dir_path, f"{basename}scale_x_orig_quant.bin", [1], np.float32)
            pre_scale_weight.value = t
            if is_qkv:
                t = fromfile(
                    dir_path,
                    f"{basename}scale_y_accum_quant.{rank}.{suffix}",
                    col_shape,
                    np.float32,
                )
            else:
                t = fromfile(
                    dir_path,
                    f"{basename}scale_y_accum_quant.{suffix}",
                    col_shape,
                    np.float32,
                )
            module.per_channel_scale.value = t
            t = fromfile(
                dir_path, f"{basename}scale_y_quant_orig.bin", [1, 1], np.float32
            )
            module.act_scale.value = t

    def set_smoother(module, dir_path, base_name, shape, rank):
        suffix = f"{rank}.bin"
        t = fromfile(dir_path, f"{base_name}.smoother.{suffix}", shape, np.float32)
        module.smoother.value = t

    # Determine the quantization mode.
    quant_mode = getattr(tensorrt_llm_llama, "quant_mode", QuantMode(0))
    if quant_mode.is_int8_weight_only():
        plugin_weight_only_quant_type = torch.int8
    elif quant_mode.is_int4_weight_only():
        plugin_weight_only_quant_type = torch.quint4x2
    # Do we use SmoothQuant?
    use_smooth_quant = quant_mode.has_act_and_weight_quant()
    # Do we use quantization per token?
    quant_per_token_dyn = quant_mode.has_per_token_dynamic_scaling()
    # Do we use quantization per channel?
    quant_per_channel = quant_mode.has_per_channel_scaling()

    # Do we use INT4/INT8 weight-only?
    use_weight_only = quant_mode.is_weight_only()

    # Int8 KV cache
    use_int8_kv_cache = quant_mode.has_int8_kv_cache()

    def sq_trick(x):
        return x.view(np.float32) if use_smooth_quant else x

    # Debug
    suffix = gen_suffix(mapping.tp_rank, use_smooth_quant, quant_per_channel)
    # The type of weights.
    w_type = np_dtype if not use_smooth_quant else np.int8

    if mapping.is_first_pp_rank():
        tensorrt_llm_llama.vocab_embedding.weight.value = fromfile(
            dir_path, "vocab_embedding.weight.bin", [vocab_size, n_embd]
        )

    if mapping.is_last_pp_rank():
        tensorrt_llm_llama.ln_f.weight.value = fromfile(dir_path, "ln_f.weight.bin")
    # share input embedding
    lm_head_weight = fromfile(dir_path, "lm_head.weight.bin", [vocab_size, n_embd])

    if vocab_size % mapping.tp_size != 0:
        # padding
        vocab_size_padded = tensorrt_llm_llama.lm_head.out_features * mapping.tp_size
        pad_width = vocab_size_padded - vocab_size
        lm_head_weight = np.pad(
            lm_head_weight, ((0, pad_width), (0, 0)), "constant", constant_values=0
        )
    if mapping.is_last_pp_rank():
        tensorrt_llm_llama.lm_head.weight.value = np.ascontiguousarray(
            split(lm_head_weight, mapping.tp_size, mapping.tp_rank)
        )

    layers_range = list(
        range(
            mapping.pp_rank * tensorrt_llm_llama.num_layers,
            (mapping.pp_rank + 1) * tensorrt_llm_llama.num_layers,
            1,
        )
    )

    for i in layers_range:
        n_groups = n_head // n_kv_head
        c_attn_out_dim = (
            (3 * n_embd // mapping.tp_size)
            if not multi_query_mode
            else (
                n_embd // mapping.tp_size
                + (n_embd // n_head * n_groups) // mapping.tp_size * 2
            )
        )
        idx = i - mapping.pp_rank * tensorrt_llm_llama.num_layers
        tensorrt_llm_llama.layers[idx].input_layernorm.weight.value = fromfile(
            dir_path, "model.layers." + str(i) + ".input_layernorm.weight.bin"
        )
        t = fromfile(
            dir_path,
            "model.layers." + str(i) + ".attention.query_key_value.weight." + suffix,
            [n_embd, c_attn_out_dim],
            w_type,
        )
        if t is not None:
            dst = tensorrt_llm_llama.layers[idx].attention.qkv.weight
            if use_smooth_quant:
                dst.value = sq_trick(np.ascontiguousarray(np.transpose(t, [1, 0])))
                set_smoothquant_scale_factors(
                    tensorrt_llm_llama.layers[idx].attention.qkv,
                    tensorrt_llm_llama.layers[idx].input_layernorm.scale_to_int,
                    dir_path,
                    "model.layers." + str(i) + ".attention.query_key_value.",
                    [1, c_attn_out_dim],
                    quant_per_token_dyn,
                    quant_per_channel,
                    rank=mapping.tp_rank,
                    is_qkv=True,
                )
            elif use_weight_only:
                (
                    processed_torch_weights,
                    torch_weight_scales,
                ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
                    torch.tensor(t), plugin_weight_only_quant_type
                )
                # workaround for trt not supporting int8 inputs in plugins currently
                dst.value = processed_torch_weights.view(dtype=torch.float32).numpy()
                scales = tensorrt_llm_llama.layers[i].attention.qkv.per_channel_scale
                scales.value = torch_weight_scales.numpy()
            else:
                dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))

        dst = tensorrt_llm_llama.layers[idx].attention.dense.weight
        t = fromfile(
            dir_path,
            "model.layers." + str(i) + ".attention.dense.weight." + suffix,
            [n_embd // mapping.tp_size, n_embd],
            w_type,
        )
        if use_smooth_quant:
            dst.value = sq_trick(np.ascontiguousarray(np.transpose(t, [1, 0])))
            dense_scale = getattr(
                tensorrt_llm_llama.layers[idx].attention,
                "quantization_scaling_factor",
                None,
            )
            set_smoothquant_scale_factors(
                tensorrt_llm_llama.layers[idx].attention.dense,
                dense_scale,
                dir_path,
                "model.layers." + str(i) + ".attention.dense.",
                [1, n_embd],
                quant_per_token_dyn,
                quant_per_channel,
            )
            set_smoother(
                tensorrt_llm_llama.layers[idx].attention.dense,
                dir_path,
                "model.layers." + str(i) + ".attention.dense",
                [1, n_embd // mapping.tp_size],
                mapping.tp_rank,
            )
        elif use_weight_only:
            (
                processed_torch_weights,
                torch_weight_scales,
            ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
                torch.tensor(t), plugin_weight_only_quant_type
            )
            # workaround for trt not supporting int8 inputs in plugins currently
            dst.value = processed_torch_weights.view(dtype=torch.float32).numpy()
            scales = tensorrt_llm_llama.layers[i].attention.dense.per_channel_scale
            scales.value = torch_weight_scales.numpy()
        else:
            dst.value = np.ascontiguousarray(np.transpose(t, [1, 0]))

        dst = tensorrt_llm_llama.layers[idx].post_layernorm.weight
        dst.value = fromfile(
            dir_path, "model.layers." + str(i) + ".post_layernorm.weight.bin"
        )

        t = fromfile(
            dir_path,
            "model.layers." + str(i) + ".mlp.fc.weight." + suffix,
            [n_embd, inter_size // mapping.tp_size],
            w_type,
        )

        if use_smooth_quant:
            tensorrt_llm_llama.layers[idx].mlp.fc.weight.value = sq_trick(
                np.ascontiguousarray(np.transpose(t, [1, 0]))
            )
            set_smoothquant_scale_factors(
                tensorrt_llm_llama.layers[idx].mlp.fc,
                tensorrt_llm_llama.layers[idx].post_layernorm.scale_to_int,
                dir_path,
                "model.layers." + str(i) + ".mlp.fc.",
                [1, inter_size // mapping.tp_size],
                quant_per_token_dyn,
                quant_per_channel,
                rank=mapping.tp_rank,
            )
        elif use_weight_only:
            dst = tensorrt_llm_llama.layers[i].mlp.fc.weight
            (
                processed_torch_weights,
                torch_weight_scales,
            ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
                torch.tensor(t), plugin_weight_only_quant_type
            )
            # workaround for trt not supporting int8 inputs in plugins currently
            dst.value = processed_torch_weights.view(dtype=torch.float32).numpy()
            scales = tensorrt_llm_llama.layers[i].mlp.fc.per_channel_scale
            scales.value = torch_weight_scales.numpy()
        else:
            tensorrt_llm_llama.layers[idx].mlp.fc.weight.value = np.ascontiguousarray(
                np.transpose(t, [1, 0])
            )

        t = fromfile(
            dir_path,
            "model.layers." + str(i) + ".mlp.gate.weight." + suffix,
            [n_embd, inter_size // mapping.tp_size],
            w_type,
        )
        if use_smooth_quant:
            tensorrt_llm_llama.layers[idx].mlp.gate.weight.value = sq_trick(
                np.ascontiguousarray(np.transpose(t, [1, 0]))
            )
            set_smoothquant_scale_factors(
                tensorrt_llm_llama.layers[idx].mlp.gate,
                tensorrt_llm_llama.layers[idx].post_layernorm.scale_to_int,
                dir_path,
                "model.layers." + str(i) + ".mlp.gate.",
                [1, inter_size // mapping.tp_size],
                quant_per_token_dyn,
                quant_per_channel,
                rank=mapping.tp_rank,
            )
        elif use_weight_only:
            dst = tensorrt_llm_llama.layers[i].mlp.gate.weight
            (
                processed_torch_weights,
                torch_weight_scales,
            ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
                torch.tensor(t), plugin_weight_only_quant_type
            )
            # workaround for trt not supporting int8 inputs in plugins currently
            dst.value = processed_torch_weights.view(dtype=torch.float32).numpy()
            scales = tensorrt_llm_llama.layers[i].mlp.gate.per_channel_scale
            scales.value = torch_weight_scales.numpy()
        else:
            tensorrt_llm_llama.layers[idx].mlp.gate.weight.value = np.ascontiguousarray(
                np.transpose(t, [1, 0])
            )

        t = fromfile(
            dir_path,
            "model.layers." + str(i) + ".mlp.proj.weight." + suffix,
            [inter_size // mapping.tp_size, n_embd],
            w_type,
        )
        if use_smooth_quant:
            tensorrt_llm_llama.layers[idx].mlp.proj.weight.value = sq_trick(
                np.ascontiguousarray(np.transpose(t, [1, 0]))
            )
            proj_scale = getattr(
                tensorrt_llm_llama.layers[idx].mlp, "quantization_scaling_factor", None
            )
            set_smoothquant_scale_factors(
                tensorrt_llm_llama.layers[idx].mlp.proj,
                proj_scale,
                dir_path,
                "model.layers." + str(i) + ".mlp.proj.",
                [1, n_embd],
                quant_per_token_dyn,
                quant_per_channel,
            )
            set_smoother(
                tensorrt_llm_llama.layers[idx].mlp.proj,
                dir_path,
                "model.layers." + str(i) + ".mlp.proj",
                [1, inter_size // mapping.tp_size],
                mapping.tp_rank,
            )
        elif use_weight_only:
            dst = tensorrt_llm_llama.layers[i].mlp.proj.weight
            (
                processed_torch_weights,
                torch_weight_scales,
            ) = torch.ops.fastertransformer.symmetric_quantize_last_axis_of_batched_matrix(
                torch.tensor(t), plugin_weight_only_quant_type
            )
            # workaround for trt not supporting int8 inputs in plugins currently
            dst.value = processed_torch_weights.view(dtype=torch.float32).numpy()
            scales = tensorrt_llm_llama.layers[i].mlp.proj.per_channel_scale
            scales.value = torch_weight_scales.numpy()
        else:
            tensorrt_llm_llama.layers[idx].mlp.proj.weight.value = np.ascontiguousarray(
                np.transpose(t, [1, 0])
            )

        if use_int8_kv_cache:
            t = fromfile(
                dir_path,
                "model.layers."
                + str(i)
                + ".attention.query_key_value.scale_y_quant_orig.bin",
                [1],
                np.float32,
            )
            tensorrt_llm_llama.layers[idx].attention.kv_orig_quant_scale.value = 1.0 / t
            tensorrt_llm_llama.layers[idx].attention.kv_quant_orig_scale.value = t

    tok = time.time()
    t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
    tensorrt_llm.logger.info(f"Weights loaded. Total time: {t}")


def load_from_gptq_llama(
    tensorrt_llm_llama, quant_ckpt_path, mapping=Mapping(), dtype="float16"
):
    """Load a groupwise GPTQ LLaMA checkpoint into ``tensorrt_llm_llama``.

    The checkpoint is read from a ``.safetensors`` or ``.pt`` file. For every
    decoder layer owned by this pipeline stage, the ``qweight`` / ``qzeros`` /
    ``scales`` tensors of the attention and MLP linears are sliced for the
    current tensor-parallel rank, repacked with the ``fastertransformer``
    torch ops and assigned to the ``qweight`` / ``scale`` / ``zero``
    parameters of the matching TensorRT-LLM module. Embedding, final norm,
    ``lm_head`` and per-layer norm weights are cast to ``dtype`` and assigned
    directly.

    Args:
        tensorrt_llm_llama: Model object whose parameters are populated in place.
        quant_ckpt_path: Path to the checkpoint; the extension selects the reader.
        mapping: Tensor/pipeline parallel layout (``tp_size``, ``tp_rank``,
            ``pp_size``, ``pp_rank`` and the first/last pipeline-rank checks
            are used).
        dtype: String dtype name used for the non-quantized weights.

    Raises:
        AssertionError: If the checkpoint extension is neither ``.safetensors``
            nor ``.pt``.
        ValueError: If no per-layer tensors are found in the checkpoint.
    """
    tensorrt_llm.logger.info("Loading weights from groupwise GPTQ LLaMA safetensors...")
    tik = time.time()

    ckpt_path = str(quant_ckpt_path)
    if ckpt_path.endswith(".safetensors"):
        # Use the handle as a context manager so the file is released once all
        # tensors have been materialised.
        with safe_open(ckpt_path, framework="pt", device=0) as ckpt:
            model_params = {key: ckpt.get_tensor(key) for key in ckpt.keys()}
    elif ckpt_path.endswith(".pt"):
        model_params = torch.load(ckpt_path, map_location=torch.device("cpu"))
    else:
        # Raised explicitly (instead of `assert False`) so the check survives
        # `python -O` while keeping the same exception type.
        raise AssertionError("Quantized checkpoint format not supported!")

    UINT4_TO_INT4_FLAG = 1
    GPTQ_FLAG = 1
    packer = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4
    preprocessor = torch.ops.fastertransformer.preprocess_weights_for_mixed_gemm
    torch_dtype = str_dtype_to_torch(dtype)
    suffixes = ("qweight", "qzeros", "scales")

    def unpack_int32_into_int8(w_packed):
        # Unpack inputs packed in int32/float32 into uint4 and store them in
        # int8 format: each byte of the packed buffer yields two output values.
        w_packed_int4x2 = w_packed.contiguous().view(torch.uint8)
        w_unpacked = torch.zeros(
            w_packed_int4x2.shape[0], w_packed_int4x2.shape[1] * 2, dtype=torch.int8
        )
        w_unpacked[:, ::2] = w_packed_int4x2 % 16  # low nibble
        w_unpacked[:, 1::2] = w_packed_int4x2 // 16  # high nibble
        return w_unpacked.contiguous()

    def preprocess_groupwise_weight_params(qweight_int32, qzeros_int32, scales_fp16):
        # Returns (interleaved qweight viewed as float32, scales, zeros * scales).
        qweight_unpacked_int8 = (
            unpack_int32_into_int8(qweight_int32.T).T.contiguous() - 8
        )
        qweight_interleaved = preprocessor(
            packer(qweight_unpacked_int8), torch.quint4x2
        ).view(torch.float32)
        # zeros = zeros * scales
        qzeros_unpacked = unpack_int32_into_int8(qzeros_int32)
        zeros_x_scales_fp16 = (
            -qzeros_unpacked + 8 * UINT4_TO_INT4_FLAG - GPTQ_FLAG
        ) * scales_fp16
        zeros_x_scales_fp16 = zeros_x_scales_fp16.half()
        return (
            qweight_interleaved.contiguous(),
            scales_fp16.contiguous(),
            zeros_x_scales_fp16.contiguous(),
        )

    def assign_groupwise(op, qweight_int32, qzeros_int32, scales_fp16):
        # Repack one linear layer and write its three quantization parameters.
        qweight, scales, zeros_x_scales = preprocess_groupwise_weight_params(
            qweight_int32, qzeros_int32, scales_fp16
        )
        op.qweight.value = qweight.numpy()
        op.scale.value = scales.numpy()
        op.zero.value = zeros_x_scales.numpy()

    def load_tp_shard(qweight_key, tp_dim):
        # Fetch qweight/qzeros/scales for the linear named by `qweight_key`
        # and keep only this tensor-parallel rank's slice along `tp_dim`.
        base = qweight_key[: -len("qweight")]
        shards = []
        for suf in suffixes:
            full = model_params[base + suf].cpu()
            shards.append(
                full.split(full.shape[tp_dim] // mapping.tp_size, dim=tp_dim)[
                    mapping.tp_rank
                ]
            )
        return shards

    def to_numpy(value):
        # Cast to the requested dtype and convert to numpy (lists element-wise).
        if isinstance(value, list):
            return [torch_to_numpy(vv.to(torch_dtype).detach().cpu()) for vv in value]
        return torch_to_numpy(value.to(torch_dtype).detach().cpu())

    # Derive the layer count from the loaded tensors so that both checkpoint
    # formats work (the safetensors handle only exists for one of them).
    layer_ids = [extract_layer_idx(key) for key in model_params]
    layer_ids = [int(layer_idx) for layer_idx in layer_ids if layer_idx is not None]
    if not layer_ids:
        raise ValueError("No per-layer tensors found in the quantized checkpoint")
    num_hidden_layers = max(layer_ids) + 1

    layers_per_pipeline_stage = num_hidden_layers // mapping.pp_size
    first_layer = mapping.pp_rank * layers_per_pipeline_stage
    # A range gives O(1) membership tests for the per-key filter below.
    layers_range = range(first_layer, first_layer + layers_per_pipeline_stage)

    # Fused QKV: concatenate the separate q/k/v projections per layer.
    for layer_id in layers_range:
        prefix = f"model.layers.{layer_id}.self_attn."
        qkv_shards = []

        for suf in suffixes:
            q_part = model_params[prefix + "q_proj." + suf].cpu()
            k_part = model_params[prefix + "k_proj." + suf].cpu()
            v_part = model_params[prefix + "v_proj." + suf].cpu()
            # Q, K and V are stacked and then viewed as three equal blocks, so
            # their shapes must match.
            qkv_part = torch.cat([q_part, k_part, v_part], dim=0)
            rows, cols = qkv_part.shape
            qkv_part = qkv_part.reshape(3, rows // 3, cols)
            shard = qkv_part.split(cols // mapping.tp_size, dim=2)[mapping.tp_rank]
            # Index directly (no squeeze): the slices are already 2-D, and a
            # squeeze would drop the row axis when a block has a single row.
            qkv_shards.append(torch.cat([shard[0], shard[1], shard[2]], dim=1))

        assign_groupwise(
            tensorrt_llm_llama.layers[layer_id - first_layer].attention.qkv,
            *qkv_shards,
        )

    # Checkpoint key fragment -> (module path on the TRT-LLM layer, TP split dim).
    linear_targets = (
        ("self_attn.o_proj.qweight", "attention.dense", 0),
        ("mlp.up_proj.qweight", "mlp.gate", 1),
        ("mlp.down_proj.qweight", "mlp.proj", 0),
        ("mlp.gate_proj.qweight", "mlp.fc", 1),
    )

    for k, v in model_params.items():
        # Tensors are converted lazily: only the handful of non-quantized
        # weights need the dtype cast / numpy conversion.
        if "model.embed_tokens.weight" in k:
            if mapping.is_first_pp_rank():
                tensorrt_llm_llama.vocab_embedding.weight.value = to_numpy(v)
        elif "model.norm.weight" in k:
            if mapping.is_last_pp_rank():
                tensorrt_llm_llama.ln_f.weight.value = to_numpy(v)
        elif "lm_head.weight" in k:
            if mapping.is_last_pp_rank():
                tensorrt_llm_llama.lm_head.weight.value = np.ascontiguousarray(
                    split(to_numpy(v), mapping.tp_size, mapping.tp_rank)
                )
        else:
            layer_idx = extract_layer_idx(k)
            if layer_idx is None:
                continue
            global_idx = int(layer_idx)
            if global_idx not in layers_range:
                continue
            layer = tensorrt_llm_llama.layers[global_idx - first_layer]

            if "input_layernorm.weight" in k:
                layer.input_layernorm.weight.value = to_numpy(v)
            elif "post_attention_layernorm.weight" in k:
                layer.post_layernorm.weight.value = to_numpy(v)
            else:
                for fragment, target, tp_dim in linear_targets:
                    if fragment in k:
                        assign_groupwise(
                            attrgetter(target)(layer), *load_tp_shard(k, tp_dim)
                        )
                        break

    tok = time.time()
    t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
    tensorrt_llm.logger.info(f"Weights loaded. Total time: {t}")


def load_from_awq_llama(
    tensorrt_llm_llama: LLaMAForCausalLM,
    quant_ckpt_path,
    mapping=Mapping(),
    dtype="float16",
):
    """Load a groupwise AWQ LLaMA checkpoint into ``tensorrt_llm_llama``.

    The checkpoint is read from a ``.safetensors`` or ``.pt`` file. For each
    linear layer the loader reads ``<name>.weight``,
    ``<name>.weight_quantizer._amax`` and
    ``<name>.input_quantizer._pre_quant_scale``, slices them for the current
    tensor-parallel rank, quantizes the weight with ``scale = amax / 8``
    (values are rounded and clamped to ``[-8, 7]``), packs it with the
    ``fastertransformer`` torch ops and assigns ``qweight`` / ``scale`` /
    ``pre_quant_scale`` on the matching TensorRT-LLM module. Q, K and V are
    fused: their weights are rescaled to a shared (averaged) pre-quant scale
    before new per-group scales are computed. The embedding table and
    ``lm_head`` are zero-padded to a multiple of 64 rows when needed.

    Args:
        tensorrt_llm_llama: Model object whose parameters are populated in place.
        quant_ckpt_path: Path to the checkpoint; the extension selects the reader.
        mapping: Tensor/pipeline parallel layout (``tp_size``, ``tp_rank``,
            ``pp_size``, ``pp_rank`` and the first/last pipeline-rank checks
            are used).
        dtype: String dtype name used for scales and non-quantized weights.

    Raises:
        AssertionError: If the checkpoint extension is neither ``.safetensors``
            nor ``.pt``.
        KeyError: If an expected tensor is missing from the checkpoint.
    """
    tensorrt_llm.logger.info("Loading weights from groupwise AWQ LLaMA safetensors...")
    tik = time.time()

    ckpt_path = str(quant_ckpt_path)
    if ckpt_path.endswith(".safetensors"):
        # Use the handle as a context manager so the file is released once all
        # tensors have been materialised.
        with safe_open(ckpt_path, framework="pt", device=0) as ckpt:
            awq_llama = {key: ckpt.get_tensor(key) for key in ckpt.keys()}
    elif ckpt_path.endswith(".pt"):
        awq_llama = torch.load(ckpt_path, map_location=torch.device("cpu"))
    else:
        # Raised explicitly (instead of `assert False`) so the check survives
        # `python -O` while keeping the same exception type.
        raise AssertionError("Quantized checkpoint format not supported!")

    # Elements per quantization group, inferred from one reference layer.
    group_size = (
        awq_llama["model.layers.0.self_attn.o_proj.weight"].numel()
        // awq_llama["model.layers.0.self_attn.o_proj.weight_quantizer._amax"].numel()
    )

    # (checkpoint attribute, TensorRT-LLM attribute) pairs copied as-is per layer.
    block_name_pairs = (
        ("input_layernorm.weight", "input_layernorm.weight"),
        ("post_attention_layernorm.weight", "post_layernorm.weight"),
    )

    packer = torch.ops.fastertransformer.pack_int8_tensor_to_packed_int4
    preprocessor = torch.ops.fastertransformer.preprocess_weights_for_mixed_gemm
    torch_dtype = str_dtype_to_torch(dtype)

    def to_numpy(tensor):
        # Cast to the requested dtype and hand back a host numpy array.
        return tensor.to(torch_dtype).cpu().numpy()

    def tp_shard(tensor, dim):
        # Keep only this tensor-parallel rank's slice along `dim`.
        return tensor.split(tensor.shape[dim] // mapping.tp_size, dim=dim)[
            mapping.tp_rank
        ]

    def quantize_pack_preprocess(weight, scale):
        # Expand the per-group scale to one row per weight row, quantize to
        # the int4 value range, then pack and reorder for the mixed GEMM.
        scale = scale.repeat_interleave(group_size, dim=0)
        weight = weight / scale
        qweight_int8 = torch.clamp(torch.round(weight.cuda()).char(), -8, 7)
        int4_weight = packer(qweight_int8.cpu())
        int4_weight = preprocessor(int4_weight, torch.quint4x2)
        return int4_weight.view(torch.float32).cpu().numpy()

    def process_and_assign_weight(prefix, op, tp_dim=0):
        # Quantize one linear layer and assign qweight/scale/pre_quant_scale.
        weight = awq_llama[prefix + ".weight"].T.contiguous()
        k, n = weight.shape
        weight = tp_shard(weight, tp_dim)
        amax = (
            awq_llama[prefix + ".weight_quantizer._amax"]
            .reshape((n, k // group_size))
            .T.contiguous()
        )
        amax = tp_shard(amax, tp_dim)
        pre_quant_scale = awq_llama[
            prefix + ".input_quantizer._pre_quant_scale"
        ].reshape((1, k))
        if tp_dim == 0:
            # The input dimension is sharded, so the per-input scale is too.
            pre_quant_scale = pre_quant_scale.split(k // mapping.tp_size, dim=1)[
                mapping.tp_rank
            ]
        scale = amax / 8.0
        op.qweight.value = quantize_pack_preprocess(weight, scale)
        op.scale.value = to_numpy(scale)
        op.pre_quant_scale.value = to_numpy(pre_quant_scale)

    def get_scale(weight):
        # Per-group max-abs over the first axis of `weight`, divided by 8.
        weight = weight.T.contiguous()
        n, k = weight.shape
        group_amax = weight.reshape(-1, group_size).abs().max(dim=1).values
        return group_amax.reshape(n, k // group_size).T.contiguous() / 8

    def resmooth_and_get_scale(weight, pre_quant_scale, avg_pre_quant_scale):
        # Undo the layer's own pre-quant scaling, apply the shared one, and
        # recompute group scales. The (1, k) scales broadcast over columns
        # once transposed to (k, 1).
        weight = weight * pre_quant_scale.T
        weight = weight / avg_pre_quant_scale.T
        return weight, get_scale(weight)

    def process_and_assign_qkv_weight(prefix, op):
        # Fuse the q/k/v projections into one quantized QKV linear.
        names = ("q_proj", "k_proj", "v_proj")
        weights = [
            awq_llama[prefix + "self_attn." + name + ".weight"].T.contiguous()
            for name in names
        ]
        k = weights[0].shape[0]
        weights = [tp_shard(w, 1) for w in weights]
        pre_quant_scales = [
            awq_llama[
                prefix + "self_attn." + name + ".input_quantizer._pre_quant_scale"
            ].reshape((1, k))
            for name in names
        ]

        # All three projections share one input, hence one averaged scale.
        qkv_pre_quant_scale = (
            pre_quant_scales[0] + pre_quant_scales[1] + pre_quant_scales[2]
        ) / 3.0
        resmoothed = [
            resmooth_and_get_scale(w, s, qkv_pre_quant_scale)
            for w, s in zip(weights, pre_quant_scales)
        ]
        qkv_weights = torch.cat([w for w, _ in resmoothed], dim=1)
        qkv_scale = torch.cat([s for _, s in resmoothed], dim=1)

        op.pre_quant_scale.value = to_numpy(qkv_pre_quant_scale)
        op.qweight.value = quantize_pack_preprocess(qkv_weights, qkv_scale)
        op.scale.value = to_numpy(qkv_scale)

    # Check if we need to pad vocab (rounded up to a multiple of 64 rows).
    embedding = awq_llama["model.embed_tokens.weight"]
    vocab_size, hidden_size = embedding.shape
    pad_vocab_size = -(-vocab_size // 64) * 64
    pad_vocab = pad_vocab_size != vocab_size
    if mapping.is_first_pp_rank():
        if pad_vocab:
            padded = torch.zeros([pad_vocab_size, hidden_size])
            padded[:vocab_size, :] = embedding
            embedding = padded
        tensorrt_llm_llama.vocab_embedding.weight.value = to_numpy(embedding)

    layer_ids = [extract_layer_idx(key) for key in awq_llama]
    layer_ids = [int(layer_idx) for layer_idx in layer_ids if layer_idx is not None]

    num_hidden_layers = max(layer_ids) + 1
    layers_per_pipeline_stage = num_hidden_layers // mapping.pp_size
    first_layer = mapping.pp_rank * layers_per_pipeline_stage

    for layer_idx in range(first_layer, first_layer + layers_per_pipeline_stage):
        prefix = "model.layers." + str(layer_idx) + "."
        tensorrt_llm.logger.info(f"Process weights in layer: {layer_idx}")
        # Checkpoint keys use the global layer index, but the model's layer
        # list is addressed relative to this pipeline stage (same convention
        # as the GPTQ loader in this file).
        layer = tensorrt_llm_llama.layers[layer_idx - first_layer]

        for awq_attr, trt_attr in block_name_pairs:
            attrgetter(trt_attr)(layer).value = to_numpy(awq_llama[prefix + awq_attr])

        # Attention QKV Linear
        # concatenate the Q, K, V layers weights.
        process_and_assign_qkv_weight(prefix, layer.attention.qkv)

        # Attention Dense (out_proj) Linear
        process_and_assign_weight(prefix + "self_attn.o_proj", layer.attention.dense, 0)

        # MLP up_proj (mlp.gate) Linear
        process_and_assign_weight(prefix + "mlp.up_proj", layer.mlp.gate, 1)

        # MLP down_proj (mlp.proj) Linear
        process_and_assign_weight(prefix + "mlp.down_proj", layer.mlp.proj, 0)

        # MLP gate_proj (mlp.fc) Linear
        process_and_assign_weight(prefix + "mlp.gate_proj", layer.mlp.fc, 1)

    # Final norm and lm_head are only assigned on the last pipeline stage.
    if mapping.is_last_pp_rank():
        tensorrt_llm_llama.ln_f.weight.value = to_numpy(awq_llama["model.norm.weight"])

        # lm_head
        if pad_vocab:
            # NOTE(review): unlike the unpadded path, this path assigns the
            # full padded tensors without a tensor-parallel split - confirm
            # this is intended when tp_size > 1.
            weight = awq_llama["lm_head.weight"]
            head_rows, k = weight.shape
            new_weight = torch.zeros([pad_vocab_size, k])
            new_weight[:head_rows, :] = weight
            new_weight = new_weight.T.contiguous()
            amax = awq_llama["lm_head.weight_quantizer._amax"].reshape(
                [head_rows, k // group_size]
            )
            # Padded rows get amax 1 so the division by the scale stays finite.
            new_amax = torch.ones([pad_vocab_size, k // group_size])
            new_amax[:head_rows, :] = amax
            new_amax = new_amax.T.contiguous()
            new_scale = new_amax / 8
            tensorrt_llm_llama.lm_head.qweight.value = quantize_pack_preprocess(
                new_weight, new_scale
            )
            tensorrt_llm_llama.lm_head.scale.value = to_numpy(new_scale)
            tensorrt_llm_llama.lm_head.pre_quant_scale.value = to_numpy(
                awq_llama["lm_head.input_quantizer._pre_quant_scale"]
            )
        else:
            process_and_assign_weight("lm_head", tensorrt_llm_llama.lm_head, 1)

    tok = time.time()
    t = time.strftime("%H:%M:%S", time.gmtime(tok - tik))
    tensorrt_llm.logger.info(f"Weights loaded. Total time: {t}")
